"""
Phase 9: Pre-Trend Fix — Augmented Local Projections
======================================================
Standard Jordà (2005) LP augmented with lagged dependent variable controls:

  y_{t+h} = α + β_h · x_t + γ₁·y_{t-1} + γ₂·y_{t-2} + controls_t + ε_{t+h}

Adding y_{t-1} and y_{t-2} absorbs pre-existing trends and mean reversion.
A third specification also adds lagged rgdp_growth as a control (to address the GDP-growth pre-trend); it is skipped when GDP growth is itself the outcome.

Re-runs OECD LPs with pre-trend check at h=-3,-2,-1.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Directory layout: this script's directory is a direct child of PROJECT_DIR,
# and ROOT_DIR is PROJECT_DIR's parent.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Put ROOT_DIR/multilateral/src on the import path; `model` (PanelGLS) is
# expected to be found there, so this must run before the import below.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory and output directory for tables. Note the output
# directory is created as a side effect of importing this module.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Post-shock horizons run 0..MAX_HORIZON; PRE_HORIZONS are the placebo leads
# used for the pre-trend check.
MAX_HORIZON = 5
PRE_HORIZONS = [-3, -2, -1]
# Controls included in every specification.
BASE_CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']

# ISO3 codes of the 38 countries kept in the OECD subsample (the panel is
# filtered on its 'iso3' column).
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'CHL', 'COL', 'CRI', 'CZE', 'DEU',
    'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GRC', 'HUN', 'IRL', 'ISL',
    'ISR', 'ITA', 'JPN', 'KOR', 'LTU', 'LUX', 'LVA', 'MEX', 'NLD', 'NOR',
    'NZL', 'POL', 'PRT', 'SVK', 'SVN', 'SWE', 'TUR', 'USA',
}

# Outcome column -> (display label, variable type). 'growth' outcomes are
# cumulated over the horizon window; any other type is read at a single date.
OUTCOMES = {
    'gross_fixed_investment_gdp': ('Investment/GDP', 'level'),
    'rgdp_growth': ('GDP Growth', 'growth'),
    'mpk_proxy': ('MPK', 'level'),
    'delta_log_kl': ('Δlog K/L', 'growth'),
}


def stars(p):
    """Return the conventional significance marker for a p-value.

    '***' when p < 0.01, '**' when p < 0.05, '*' when p < 0.1, and '' otherwise
    (including a NaN p-value, because every comparison with NaN is False).
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def prepare_lags(df, y_var):
    """Add y_{t-1}, y_{t-2}, and lagged rgdp_growth to dataframe."""
    df = df.sort_values(['iso3', 'year']).copy()
    df[f'{y_var}_lag1'] = df.groupby('iso3')[y_var].shift(1)
    df[f'{y_var}_lag2'] = df.groupby('iso3')[y_var].shift(2)
    if 'rgdp_growth' in df.columns and y_var != 'rgdp_growth':
        df['rgdp_growth_lag1'] = df.groupby('iso3')['rgdp_growth'].shift(1)
    return df


def build_horizon_outcome(df, y_var, h, var_type):
    """Return a sorted copy of ``df`` with the horizon-``h`` outcome ``{y_var}_h{h}``.

    With t the row's year, the outcome is:

      h == 0, any type       : y_t
      h >  0, 'growth'       : y_t + y_{t+1} + ... + y_{t+h}   (cumulative)
      h >  0, other types    : y_{t+h}
      h <  0, 'growth'       : y_{t-|h|} + ... + y_{t-1}        (pre-period window)
      h <  0, other types    : y_{t-|h|}

    Every term is matched on calendar year within the same country (iso3). If
    a required country-year is absent, the outcome is NaN; a purely positional
    shift/rolling window would instead silently span the gap. For a panel
    without gaps the values equal the positional computation.

    Raises
    ------
    pandas.errors.MergeError
        For non-growth outcomes at h != 0, if an (iso3, year) pair is
        duplicated.
    """
    df = df.sort_values(['iso3', 'year']).copy()
    out_col = f'{y_var}_h{h}'

    if h == 0:
        df[out_col] = df[y_var]
        return df

    if var_type != 'growth':
        # y_{t+h}: the value observed in year s belongs to the row for year s-h.
        source = df[['iso3', 'year', y_var]].assign(year=df['year'] - h)
        aligned = df[['iso3', 'year']].merge(source, on=['iso3', 'year'],
                                             how='left', validate='many_to_one')
        df[out_col] = aligned[y_var].to_numpy()
        return df

    by_country = df.groupby('iso3')
    if h > 0:
        window, offset = h + 1, -h   # rows t .. t+h
    else:
        window, offset = -h, 1       # rows t-|h| .. t-1
    window_sum = by_country[y_var].transform(
        lambda s: s.rolling(window=window, min_periods=window).sum().shift(offset)
    )
    # Years are sorted within a country, so the window covers consecutive
    # years exactly when the row |h| positions away is dated |h| years away.
    far_year = by_country['year'].shift(-h)
    df[out_col] = window_sum.where(far_year == df['year'] + h)
    return df


def run_lp(df, y_col, x_vars, controls, key_var):
    """Run one LP regression with PanelGLS and return the results for ``key_var``.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing ``y_col``, 'iso3', 'year' and the regressors.
    y_col : str
        Outcome column (e.g. ``{y_var}_h{h}``).
    x_vars, controls : list of str
        Regressors, in order. Duplicate names are dropped (first occurrence
        kept) and names not present in ``df`` are skipped.
    key_var : str
        Regressor whose coefficient is reported.

    Returns
    -------
    dict or None
        'coef', 'se', 'p', 'ci_lo'/'ci_hi' (coef -/+ 1.96*se), 'n_obs' and
        'r_squared'; None if ``key_var`` is not an available regressor, fewer
        than 30 complete rows remain, or the fit raises
        ``numpy.linalg.LinAlgError`` (e.g. a singular design matrix).
    """
    regressors = [v for v in dict.fromkeys(x_vars + controls) if v in df.columns]
    # Decide before fitting, so no regression is run whose key variable is absent.
    if key_var not in regressors:
        return None

    needed = [c for c in dict.fromkeys([y_col] + regressors + ['iso3', 'year'])
              if c in df.columns]
    sub = df[needed].dropna()
    if len(sub) < 30:
        return None

    gls = PanelGLS()
    try:
        gls.fit(sub[y_col].values, sub[regressors].values,
                sub['iso3'].values, sub['year'].values)
    except np.linalg.LinAlgError as err:
        # One degenerate regression should not abort the whole run; report it
        # and let callers display this horizon as "n/a".
        print(f"  [run_lp] {y_col} on {key_var}: fit failed ({err}); skipped")
        return None

    idx = regressors.index(key_var)
    beta, se = gls.beta[idx], gls.se[idx]
    return {
        'coef': beta,
        'se': se,
        'p': gls.pvalues[idx],
        'ci_lo': beta - 1.96 * se,
        'ci_hi': beta + 1.96 * se,
        'n_obs': gls.n_obs,
        'r_squared': gls.r_squared,
    }


def make_ascii_irf(irf_data, title, width=60):
    """Render an impulse response as an ASCII table with a confidence-band plot.

    Parameters
    ----------
    irf_data : dict
        {h: result}, where result is a ``run_lp``-style dict with 'coef', 'se',
        'p', 'ci_lo', 'ci_hi' and 'n_obs', or None when not estimated.
    title : str
        Heading printed (and underlined) above the table.
    width : int
        Number of character columns in the plot area.

    Returns
    -------
    str
        The rendered block. Horizons without a usable estimate (None, or a
        non-finite coef/ci_lo/ci_hi) are listed as "n/a"; if no horizon is
        usable, a short "(no data)" note is returned instead.

    In each plot row '-' spans the 95% CI, '|' marks zero (kept visible even
    when the CI covers it), and the point estimate is drawn as '*' (p<0.05),
    'o' (p<0.1) or '·' (otherwise). The axis min/max labels are printed
    under the left and right edges of the plot area.
    """
    def usable(pt):
        # A NaN/inf bound would poison the axis range and crash int() in pos().
        return bool(pt) and all(
            np.isfinite(pt.get(key, np.nan)) for key in ('coef', 'ci_lo', 'ci_hi')
        )

    points = sorted(irf_data.items())
    bounds = [v for _, pt in points if usable(pt) for v in (pt['ci_lo'], pt['ci_hi'])]
    if not bounds:
        return title + '\n  (no data)\n'

    # Always include zero on the axis so the reference line can be drawn.
    vmin = min(min(bounds), 0)
    vmax = max(max(bounds), 0)
    span = vmax - vmin if vmax > vmin else 1

    def pos(val):
        # Column of `val` in the plot area, clamped against float rounding.
        return min(max(int((val - vmin) / span * (width - 1)), 0), width - 1)

    zero_pos = pos(0)
    # Blank left margin as wide as the h/coef/SE/p/N columns (42 chars), shared
    # by the separator and axis-label rows so they line up with the plot.
    margin = f'     {"":>12s} {"":>8s} {"":>6s}  {"":>5s}  '

    lines = [title, '=' * len(title), '']
    lines.append(f'  h  {"coef":>12s} {"SE":>8s} {"p":>6s}  {"N":>5s}  Plot')
    lines.append(f'  -  {"----":>12s} {"--":>8s} {"---":>6s}  {"---":>5s}  ' + '-' * width)

    for h, pt in points:
        if not usable(pt):
            lines.append(f' {h:2d}  {"n/a":>12s}')
            continue

        p = pt['p']
        coef_str = f"{pt['coef']:.5f}{stars(p)}"

        plot = [' '] * width
        for i in range(pos(pt['ci_lo']), pos(pt['ci_hi']) + 1):
            plot[i] = '-'
        plot[zero_pos] = '|'
        plot[pos(pt['coef'])] = '*' if p < 0.05 else 'o' if p < 0.1 else '·'

        lines.append(
            f' {h:2d}  {coef_str:>12s} {pt["se"]:8.5f} {p:6.4f}  '
            f'{int(pt["n_obs"]):5d}  {"".join(plot)}'
        )

    lines.append(margin + '-' * width)
    lines.append(f'{margin}{vmin:<10.4f}{" " * (width - 20)}{vmax:>10.4f}')
    lines.append('  (* p<0.05, o p<0.1, · p>0.1, | = zero)')
    lines.append('')
    return '\n'.join(lines)


def run_comparison(df, y_var, y_label, var_type, key_var='log_predicted_demo_inflows'):
    """
    Run three LP specifications side by side:
      A: Unadjusted (Phase 7 baseline)
      B: + y_{t-1}, y_{t-2}
      C: + y_{t-1}, y_{t-2}, rgdp_growth_{t-1}   (only when y_var != 'rgdp_growth')

    Pre-period horizons (h < 0) in specs B and C: the outcome y_{t+h} (for
    growth variables the window sum y_{t-|h|} + ... + y_{t-1}) overlaps the
    regressors y_{t-1}, y_{t-2}. At h=-1 the outcome is literally the
    regressor y_{t-1}; at h=-2 it is y_{t-2} (level) or y_{t-1} + y_{t-2}
    (growth). The placebo regression would then be a mechanical identity that
    cannot reveal a pre-trend. For h < 0 the two lags are therefore re-dated to
    the two years just before the outcome window, y_{t-|h|-1} and y_{t-|h|-2}.
    All other controls keep their time-t dating.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with 'iso3', 'year', ``y_var``, ``key_var`` and the controls.
    y_var : str
        Outcome column.
    y_label : str
        Display label; not used here, kept for interface compatibility.
    var_type : str
        'growth' for cumulated outcomes; anything else is a level.
    key_var : str
        Shock variable whose coefficient is reported.

    Returns
    -------
    dict
        {spec_label: {h: run_lp result or None}}, horizons in the order
        PRE_HORIZONS followed by 0..MAX_HORIZON.

    Raises
    ------
    pandas.errors.MergeError
        If an (iso3, year) pair is duplicated (the re-dated lags would be
        ambiguous).
    """
    df = prepare_lags(df, y_var)
    lag_cols = [f'{y_var}_lag1', f'{y_var}_lag2']

    specs = {
        'A: Unadjusted': BASE_CONTROLS,
        'B: + 2 lags of y': BASE_CONTROLS + lag_cols,
    }
    # Spec C: also add lagged GDP growth (if not already the dependent var)
    if y_var != 'rgdp_growth' and 'rgdp_growth_lag1' in df.columns:
        specs['C: + 2 lags + lag growth'] = BASE_CONTROLS + lag_cols + ['rgdp_growth_lag1']

    def lag_by_year(frame, col, k):
        # frame[col] from the same country exactly k years earlier (NaN if absent).
        source = frame[['iso3', 'year', col]].assign(year=frame['year'] + k)
        aligned = frame[['iso3', 'year']].merge(source, on=['iso3', 'year'],
                                                how='left', validate='many_to_one')
        return aligned[col].to_numpy()

    results = {spec_label: {} for spec_label in specs}
    for h in PRE_HORIZONS + list(range(MAX_HORIZON + 1)):
        # The outcome depends only on h, so build it once and reuse it for
        # every spec. build_horizon_outcome returns a fresh copy, so
        # overwriting lag columns below does not touch `df`.
        df_h = build_horizon_outcome(df, y_var, h, var_type)
        if h < 0:
            for j, lag_col in enumerate(lag_cols, start=1):
                df_h[lag_col] = lag_by_year(df_h, y_var, -h + j)
        y_col = f'{y_var}_h{h}'
        for spec_label, controls in specs.items():
            results[spec_label][h] = run_lp(df_h, y_col, [key_var], controls, key_var)

    return results


def _print_comparison_table(specs, y_label):
    """Print one outcome's horizon-by-specification coefficient table to stdout."""
    headers = list(specs)
    print(f"\n  Comparison: {y_label}")
    print(f"  {'h':>3s}" + ''.join(f"  {label:>28s}" for label in headers))
    print(f"  {'─'*3}" + ''.join(f"  {'─'*28}" for _ in headers))
    for h in PRE_HORIZONS + list(range(MAX_HORIZON + 1)):
        cells = []
        for spec_label in headers:
            pt = specs[spec_label].get(h)
            if pt:
                cells.append(
                    f"  {pt['coef']:8.4f}{stars(pt['p']):3s} (p={pt['p']:.3f}) N={pt['n_obs']:4d}"
                )
            else:
                cells.append(f"  {'n/a':>28s}")
        print(f"  {h:3d}" + ''.join(cells))
    print()


def _pretrend_status(irf):
    """Summarize how many pre-period (h < 0) coefficients are significant at 10%."""
    # An estimate with a NaN p-value says nothing about a pre-trend, and a
    # missing estimate is not evidence of "no pre-trend": neither counts.
    assessed = [irf[h] for h in PRE_HORIZONS
                if irf.get(h) and np.isfinite(irf[h]['p'])]
    if not assessed:
        return 'n/a (no pre-period estimates)'
    n_sig = sum(1 for pt in assessed if pt['p'] < 0.1)
    if n_sig:
        return f'{n_sig}/{len(assessed)} sig'
    if len(assessed) < len(PRE_HORIZONS):
        return f'CLEAN ({len(assessed)}/{len(PRE_HORIZONS)} horizons estimated)'
    return 'CLEAN'


def _write_markdown_report(path, all_results, all_plots):
    """Write the Table 9 Markdown report (per-outcome tables + ASCII IRFs) to ``path``."""
    horizons = PRE_HORIZONS + list(range(MAX_HORIZON + 1))
    # Explicit UTF-8: titles, labels and plots contain characters such as
    # 'Δ' and '→' that a platform default codec (e.g. cp1252) cannot encode.
    with open(path, 'w', encoding='utf-8') as f:
        f.write("# Table 9: Augmented Local Projections — Pre-Trend Fix\n\n")
        f.write("Adding y_{t-1} and y_{t-2} as controls absorbs pre-existing trends.\n")
        f.write("Spec C also adds lagged GDP growth.\n\n")

        for y_var, (y_label, _) in OUTCOMES.items():
            specs = all_results.get(y_var, {})
            if not specs:
                continue

            headers = list(specs)
            f.write(f"\n## {y_label}\n\n")
            f.write(f"| h | {' | '.join(headers)} |\n")
            f.write(f"|---|{'|'.join(['---'] * len(headers))}|\n")
            for h in horizons:
                cells = []
                for spec_label in headers:
                    pt = specs[spec_label].get(h)
                    cells.append(
                        f"{pt['coef']:.4f}{stars(pt['p'])} (N={pt['n_obs']})" if pt else ''
                    )
                f.write(f"| {h} | {' | '.join(cells)} |\n")
            f.write("\n")

        f.write("\n```\n")
        for plot in all_plots:
            f.write(plot + '\n')
        f.write("```\n")

        # Escape the literal asterisks so Markdown does not parse them as emphasis.
        f.write("\n*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01. OECD subsample. PanelGLS with AR(1).*\n")


def main():
    """Run Phase 9 end to end on the OECD subsample.

    Loads DATA / "deepening_panel.csv", keeps the OECD countries, runs the
    specifications from ``run_comparison`` for every outcome in OUTCOMES that
    the panel contains, prints IRF plots, comparison tables and a pre-trend
    summary, and writes ``pretrend_fix_results.csv`` and ``pretrend_fix.md``
    to OUT_TABLES.
    """
    print("=" * 70)
    print("PHASE 9: PRE-TREND FIX — AUGMENTED LOCAL PROJECTIONS")
    print("=" * 70)

    df = pd.read_csv(DATA / "deepening_panel.csv")
    df_oecd = df[df['iso3'].isin(OECD)].copy()
    print(f"OECD panel: {len(df_oecd)} obs, {df_oecd['iso3'].nunique()} countries")

    all_results = {}
    all_plots = []

    for y_var, (y_label, var_type) in OUTCOMES.items():
        if y_var not in df_oecd.columns:
            # Make the omission visible instead of silently dropping an outcome.
            print(f"\nSkipping {y_label}: column '{y_var}' not in panel")
            continue

        print(f"\n{'='*70}")
        print(f"OUTCOME: {y_label}")
        print(f"{'='*70}")

        specs = run_comparison(df_oecd, y_var, y_label, var_type)
        all_results[y_var] = specs

        for spec_label, irf in specs.items():
            plot = make_ascii_irf(irf, f'{spec_label}: Demo Inflows → {y_label} (OECD)')
            print(plot)
            all_plots.append(plot)

        _print_comparison_table(specs, y_label)

    # ── Pre-trend assessment ──
    print("\n" + "=" * 70)
    print("PRE-TREND ASSESSMENT")
    print("=" * 70)

    for y_var, (y_label, _) in OUTCOMES.items():
        specs = all_results.get(y_var, {})
        if not specs:
            continue

        print(f"\n  {y_label}:")
        for spec_label, irf in specs.items():
            print(f"    {spec_label:35s}: pre-trend {_pretrend_status(irf)}")

            # Show the h=2 result (key J-curve horizon)
            h2 = irf.get(2)
            if h2:
                print(f"      h=2: β={h2['coef']:.4f} (p={h2['p']:.4f}) {stars(h2['p'])}")

    # ── Save ──
    print("\n--- Saving results ---")

    rows = [
        {'outcome': y_var, 'spec': spec_label, 'horizon': h, **pt}
        for y_var, specs in all_results.items()
        for spec_label, irf in specs.items()
        for h, pt in irf.items()
        if pt
    ]
    csv_path = OUT_TABLES / "pretrend_fix_results.csv"
    md_path = OUT_TABLES / "pretrend_fix.md"
    pd.DataFrame(rows).to_csv(csv_path, index=False)
    _write_markdown_report(md_path, all_results, all_plots)

    print(f"Saved: {csv_path}")
    print(f"Saved: {md_path}")
    print("\nPhase 9 complete.")


# Run the full Phase 9 pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
